/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once
#include "ldgsts.cuh"
#include "mha.h"
#include "utils.cuh"

// Indirect head pointer for beam search: token i is resolved through a per-token beam index
// before the (optionally paged) KV-cache head address is computed.
template <typename Head, uint32_t tokensPerPage, uint32_t nbPages>
struct IndexedHeadPtrImpl {
  static_assert(tokensPerPage != 0 && nbPages != 0);
  uint32_t const* indices;  // values are in range [0, beamWidth)
  Head* pool;
  Vec<KVCachePageIndex, nbPages> const* pageIndices;
  uint32_t tokenOffset;   // token offset within the first page
  uint32_t headIdx;       // head index
  uint32_t stride_page;   // stride for each page (in units of Head)
  uint32_t stride_token;  // stride for each token (in units of Head)
  uint32_t stride_head;   // stride for each head (in units of Head)

  __device__ inline Head& operator[](uint32_t i) const { return *(*this + i); }

  __device__ inline Head* operator+(uint32_t i) const {
    uint32_t const beamIdx = indices[i];
    assert(beamIdx < beamWidth);
    uint32_t const absoluteTokenIdx = tokenOffset + i;
    auto const pageIdx = pageIndices[beamIdx][nbPages == 1 ? 0U : absoluteTokenIdx / tokensPerPage];
    return pool + pageIdx * stride_page + (absoluteTokenIdx % tokensPerPage) * stride_token +
           headIdx * stride_head;
  }
};
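// Illustrative address computation for token i (names as above; the concrete page size is
// hypothetical): with tokensPerPage == 64, beamIdx = indices[i] picks the beam, the page is
// pageIndices[beamIdx][(tokenOffset + i) / 64], and the head is located at
// pool + pageIdx * stride_page + ((tokenOffset + i) % 64) * stride_token + headIdx * stride_head.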

template <typename Head>
struct IndexedHeadPtrImpl<Head, 0, 0> {
  uint32_t const* indices;  // values are in range [0, beamWidth)
  Head* pointer;
  uint32_t offset;
  uint32_t beamStride;

  __device__ inline Head& operator[](uint32_t i) const { return *(*this + i); }

  __device__ inline Head* operator+(uint32_t i) const {
    assert(indices[i] < beamWidth);
    return pointer + (beamStride * indices[i] + offset + i);
  }
};
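// In the contiguous (non-paged) case, token i simply resolves to
// pointer + beamStride * indices[i] + offset + i.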

template <typename Head, uint32_t tokensPerPage, uint32_t nbPages = 0>
using IndexedHeadPtr = IndexedHeadPtrImpl<Head, tokensPerPage, nbPages>;

// Head pointer for beamWidth == 1: no beam indirection, only the paged address computation.
template <typename Head, uint32_t tokensPerPage, uint32_t nbPages>
struct HeadPtr {
  static_assert(tokensPerPage != 0 && nbPages != 0);
  Head* pool;
  Vec<KVCachePageIndex, nbPages> pageIndices;
  uint32_t tokenOffset;   // token offset within the first page
  uint32_t headIdx;       // head index
  uint32_t stride_page;   // stride for each page (in units of Head)
  uint32_t stride_token;  // stride for each token (in units of Head)
  uint32_t stride_head;   // stride for each head (in units of Head)

  __device__ inline Head& operator[](uint32_t i) const { return *(*this + i); }

  __device__ inline Head* operator+(uint32_t i) const {
    uint32_t const absoluteTokenIdx = tokenOffset + i;
    auto const pageIdx = pageIndices[nbPages == 1 ? 0U : absoluteTokenIdx / tokensPerPage];
    return (pageIdx & (1U << 31))
               ? nullptr
               : pool + pageIdx * stride_page + (absoluteTokenIdx % tokensPerPage) * stride_token +
                     headIdx * stride_head;
  }
};
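// Note: a page index with the sign bit set (e.g. an out-of-bound page filled with
// kBAD_PAGE_INDEX) makes operator+ return nullptr; copyPartialHeadsAsync below turns such heads
// into zero-sized async copies instead of dereferencing them.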

template <typename Head>
struct HeadPtr<Head, 0, 0> : TinyPtr<Head> {};

// template <typename Head>
// #if BEAM_WIDTH == 1
// using SrcHeadPtr = TinyPtr<Head const>;
// #else
// using SrcHeadPtr = IndexedHeadPtr<Head>;
// #endif

// @fixme: give an evict-first hint for the last part.
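// One warp asynchronously copies part idxPart (1/nbPartsPerHead of each head) of up to
// maxNbCopiedHeads heads from src (global memory) into dst (shared memory), using cp.async via
// ldgsts::copyAsync. Out-of-bound heads/grains and invalid (nullptr) pages are issued with a
// source size of 0 so that no invalid global address is read.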
template <typename Head, uint32_t maxNbCopiedHeads, uint32_t nbPartsPerHead, bool swizzle,
          bool isFull, uint32_t dstNbHeads, typename SrcHeadPtr,
          typename LocalHeadIdxMap = uint32_t (*)(uint32_t)>
__device__ inline void copyPartialHeadsAsync(
    Warp const& warp,
    Array2D<LdGrain, dstNbHeads, exactDiv(exactDiv(sizeof(Head), nbPartsPerHead), grainBytes)>& dst,
    uint32_t dstHeadOffset, SrcHeadPtr const& src, uint32_t idxPart,
    uint32_t nbAvailHeads = maxNbCopiedHeads,
    LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; }) {
  static_assert(maxNbCopiedHeads <= dstNbHeads);
  assert(idxPart < nbPartsPerHead);
  assert(dstHeadOffset + maxNbCopiedHeads <= dstNbHeads);
  assert(sizeof(Head) * (src.offset + maxNbCopiedHeads) <= (1UL << 32));
  assert(!isFull || nbAvailHeads >= maxNbCopiedHeads);
  constexpr uint32_t headBytes = sizeof(Head);
  constexpr uint32_t partBytes = exactDiv(headBytes, nbPartsPerHead);
  constexpr uint32_t warpLdBytes = partBytes * maxNbCopiedHeads;
  constexpr uint32_t thrdLdBytes = exactDiv(warpLdBytes, warp_size);
  assertIsPowerOf2<thrdLdBytes>();
  static_assert(thrdLdBytes >= grainBytes);
  // a segment is responsible for loading one partial head collaboratively
  constexpr uint32_t thrdsPerSeg = exactDiv(partBytes, grainBytes);
  static_assert(thrdsPerSeg > 0 && thrdsPerSeg <= warp_size);
  assertIsPowerOf2<thrdsPerSeg>();
  assert(__shfl_sync(0xFU << (laneId() / 4 * 4), src.offset, 0, 4) == src.offset);
  auto const warpLane = laneId();
  uint32_t const segIdx = warpLane / thrdsPerSeg;
  uint32_t const segLane = warpLane % thrdsPerSeg;
  constexpr uint32_t partsPerWarpInst = exactDiv(grainBytes * warp_size, partBytes);
#pragma unroll
  for (uint32_t i = 0; i < thrdLdBytes / grainBytes; i++) {
    uint32_t const idxHeadLocal = partsPerWarpInst * i + segIdx;
    assert(idxHeadLocal < maxNbCopiedHeads);
    bool const isHeadInBound = isFull || (idxHeadLocal < nbAvailHeads);
    constexpr uint32_t grainsPerPart = exactDiv(partBytes, grainBytes);
    using SrcHead = mha::decay_t<decltype(src[0])>;
    constexpr uint32_t nbValidGrains = exactDiv(sizeof(SrcHead), grainBytes);
    uint32_t const idxGrainInsideHead = grainsPerPart * idxPart + segLane;
    bool const isGrainInBound = (!isHeadPadded || idxGrainInsideHead < nbValidGrains);
    SrcHead const* const pSrcHead = src + localHeadIdxMap(idxHeadLocal);
    bool const isValidPage = (pSrcHead != nullptr);
    LdGrain const* const pSrc = reinterpret_cast<LdGrain const*>(pSrcHead) + idxGrainInsideHead;
    LdGrain* const pDst = &dst.template at<swizzle>(dstHeadOffset + idxHeadLocal, segLane);
    assert(!hasBankConflict(pDst));
    ldgsts::copyAsync<grainBytes>(pDst, pSrc,
                                  isValidPage && isHeadInBound && isGrainInBound ? grainBytes : 0u);
  }
}

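// Copies whole heads (a special case of copyPartialHeadsAsync with nbPartsPerHead == 1): the
// maxNbCopiedHeads heads are split evenly across nbWarps and each warp handles its own slice.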
template <typename Head, uint32_t maxNbCopiedHeads, uint32_t nbWarps, bool swizzle, bool isFull,
          uint32_t dstNbHeads, typename SrcHeadPtr,
          typename LocalHeadIdxMap = uint32_t (*)(uint32_t)>
__device__ inline void copyHeadsAsync(
    uint32_t idxWarp, Array2D<LdGrain, dstNbHeads, exactDiv(sizeof(Head), grainBytes)>& dst,
    SrcHeadPtr const& src, uint32_t nbAvailHeads = maxNbCopiedHeads,
    LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; }) {
  assert(idxWarp < nbWarps);
  Warp const& warp = this_warp();
  constexpr uint32_t maxNbHeadsPerWarp = exactDiv(maxNbCopiedHeads, nbWarps);
  uint32_t const dstHeadOffset = maxNbHeadsPerWarp * idxWarp;
  uint32_t const warpNbAvailHeads =
      (dstHeadOffset < nbAvailHeads ? nbAvailHeads - dstHeadOffset : 0);
  constexpr uint32_t idxPart = 0;
  copyPartialHeadsAsync<Head, maxNbHeadsPerWarp, 1, swizzle, isFull, dstNbHeads>(
      warp, dst, dstHeadOffset, src, idxPart, warpNbAvailHeads,
      [&](uint32_t x) { return localHeadIdxMap(dstHeadOffset + x); });
}

template <bool isAsync, uint32_t maxTotalNbGrains, uint32_t nbWarps, bool isFull = true>
__device__ inline void copyGrains(uint32_t idxWarp, LdGrain* dst, LdGrain const* src,
                                  uint32_t totalNbGrains = maxTotalNbGrains) {
  assert((isFull && totalNbGrains == maxTotalNbGrains) ||
         (!isFull && totalNbGrains <= maxTotalNbGrains));
  constexpr uint32_t nbThrds = warp_size * nbWarps;
  uint32_t const tid = warp_size * idxWarp + laneId();
// grain-by-grain copy (e.g. output to scratch); threads stride by nbThrds across the grains
#pragma unroll
  for (uint32_t i = 0; i < divUp(maxTotalNbGrains, nbThrds); i++) {
    uint32_t const idx = nbThrds * i + tid;
    if (!(isFull && maxTotalNbGrains % nbThrds == 0) && idx >= totalNbGrains) {
      break;
    }
    if constexpr (isAsync) {
      ldgsts::copyAsync<grainBytes>(&dst[idx], &src[idx], grainBytes);
    } else {
      dst[idx] = src[idx];
    }
  }
}

// With ldmatrix, an fp8 K cache is loaded as T0:{e0,e1,e2,e3}; T1:{e4,e5,e6,e7};
// T2:{e8,e9,e10,e11}; T3:{e12,e13,e14,e15}. After the cast to fp16, the two register halves hold
// T0:{e0,e1}; T1:{e4,e5}; ... | T0:{e2,e3}; T1:{e6,e7}; ..., so Q must be reordered to match that
// element order. Pass isFwd=false to revert the reorder.
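// Worked example for the fwd reorder (one thread, two adjacent 16B grains holding 16b elements
// e0..e15): {e0..e7 | e8..e15} becomes {e0,e1,e4,e5,e8,e9,e12,e13 | e2,e3,e6,e7,e10,e11,e14,e15};
// isFwd=false applies the inverse permutation.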
template <uint32_t nbWarps, bool swizzled, bool isFwd, uint32_t cols, uint32_t rows>
__device__ inline void reorder16bQHeadsToMatch8bKCache(uint32_t idxWarp,
                                                       Array2D<LdGrain, rows, cols>& qHeads) {
  assert(idxWarp < nbWarps);
  constexpr uint32_t nbWarpIters = exactDiv(exactDiv(cols, 2) * rows, warp_size);  // warps * iters
  constexpr uint32_t nbWorkingWarps = mha::min(nbWarps, nbWarpIters);
  if (idxWarp >= nbWorkingWarps) {
    return;
  }
  static_assert(cols % 2 == 0);
  uint32_t const tid = warp_size * idxWarp + laneId();
  constexpr uint32_t iterCols = exactDiv(warp_size * nbWorkingWarps, rows) * 2;
  static_assert(cols % iterCols == 0,
                "fix this by reducing nbWorkingWarps, or use divUp and add runtime check");
  constexpr uint32_t nbIters = exactDiv(cols, iterCols);
  static_assert(nbIters == exactDiv(nbWarpIters, nbWorkingWarps));
  uint32_t const r = tid % rows;
  uint32_t const cInit = tid / rows * 2;
#pragma unroll
  for (uint32_t n = 0; n < nbIters; n++) {
    uint32_t const c = cInit + iterCols * n;
    LdGrain const src[2] = {
        qHeads.template at<swizzled>(r, c),
        qHeads.template at<swizzled>(r, c + 1),
    };
    auto const& s = reinterpret_cast<Vec<uint32_t, LdGrain::size * 2> const&>(src);
    if constexpr (isFwd) {
      qHeads.template at<swizzled>(r, c) = LdGrain{s[0], s[2], s[4], s[6]};
      qHeads.template at<swizzled>(r, c + 1) = LdGrain{s[1], s[3], s[5], s[7]};
    } else {
      qHeads.template at<swizzled>(r, c) = LdGrain{s[0], s[4], s[1], s[5]};
      qHeads.template at<swizzled>(r, c + 1) = LdGrain{s[2], s[6], s[3], s[7]};
    }
  }
}

template <bool usePagedKVCache>
struct KVCacheList;

template <>
struct KVCacheList<true> {
  GMemCacheHead* kCacheVLLM;
  GMemCacheHead* vCacheVLLM;
  KVCachePageIndex const*
      kvCachePageList;  // shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
  SeqLenDataType const* seqLenList;  // shape: [batchSize][beamWidth] (for compatibility)
  uint32_t maxNbPagesPerSeq;
};
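// Flattened index into kvCachePageList for (idxReq, idxBeam, K/V, idxPage), per the shape above
// (this is the layout assumed by loadPagesForBeamSearchAsync below):
//   ((idxReq * beamWidth + idxBeam) * 2 + (isK ? 0 : 1)) * maxNbPagesPerSeq + idxPage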

template <>
struct KVCacheList<false> {
  GMemKVCacheHead* data;  // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity]
  SeqLenDataType const* seqLenList;  // shape: [batchSize][beamWidth] (for compatibility)
  uint32_t capacity;
};
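// Flattened index into data for (idxReq, idxBeam, K/V, idxKHead, idxToken), per the shape above
// (illustrative; derived from the declared layout rather than from a call site in this file):
//   (((idxReq * beamWidth + idxBeam) * 2 + (isK ? 0 : 1)) * nbKHeads + idxKHead) * capacity + idxToken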

__device__ inline uint32_t getSeqLen(uint32_t const* seqLenList, uint32_t idxReq) {
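  // Create an L2 evict-last policy so the loaded length tends to stay resident in L2, then read
  // beam 0's length with a non-coherent load; all beams of a request are expected to share the
  // same length (checked by the assert loop below).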
  uint64_t cachePolicy;
  asm("createpolicy.fractional.L2::evict_last.b64 %0;\n" : "=l"(cachePolicy));
  uint32_t len;
  asm("ld.global.nc.L1::evict_last.L2::cache_hint.L2::256B.b32 %0, [%1], %2;\n"
      : "=r"(len)
      : "l"(&seqLenList[idxReq * beamWidth]), "l"(cachePolicy));
  for (uint32_t i = 0; i < beamWidth; i++) {
    assert(len == seqLenList[idxReq * beamWidth + i]);
  }
  return len;
}

template <bool isPaged>
__device__ inline uint32_t getCacheSeqLen(KVCacheList<isPaged> const& cacheList, uint32_t idxReq) {
  return getSeqLen(cacheList.seqLenList, idxReq);
}

__device__ inline uint32_t getCtxCacheSeqLen(BeamSearchParams const& beamSearchParams,
                                             uint32_t idxReq) {
  return getSeqLen(beamSearchParams.ctxLenList, idxReq);
}

template <uint32_t nbLoadedPages>
__device__ inline Vec<KVCachePageIndex, nbLoadedPages> getPage(KVCacheList<true> const& cacheList,
                                                               bool isK, uint32_t idxReq,
                                                               uint32_t idxBeam,
                                                               uint32_t idxPageBeg,
                                                               uint32_t nbPages) {
  auto const maxNbPagesPerSeq = cacheList.maxNbPagesPerSeq;
  Vec<KVCachePageIndex, nbLoadedPages> ret;
#pragma unroll
  for (uint32_t i = 0; i < nbLoadedPages; i++) {
    uint32_t const idxPage = idxPageBeg + i;
    ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[maxNbPagesPerSeq * idxReq + idxPage]
                                : kBAD_PAGE_INDEX);
  }
  return ret;
}

template <uint32_t nbWarps, uint32_t nbLoadedPages>
__device__ inline void loadPagesForBeamSearchAsync(
    uint32_t idxWarp, Vec<Vec<KVCachePageIndex, nbLoadedPages>, beamWidth>& dst,
    KVCacheList<true> const& cacheList, bool isK, uint32_t idxReq, uint32_t idxPageBeg,
    uint32_t nbPages) {
  assert(idxWarp < nbWarps);
  auto const maxNbPagesPerSeq = cacheList.maxNbPagesPerSeq;
  static_assert(beamWidth < warp_size);
  auto const tid = warp_size * idxWarp + laneId();
  auto const idxBeam = tid / nbLoadedPages;
  auto const idxLoadedPage = tid % nbLoadedPages;
  static_assert(warp_size * nbWarps >= beamWidth * nbLoadedPages);
  if (idxBeam < beamWidth) {
    constexpr uint32_t nbBytes = sizeof(KVCachePageIndex);
    uint32_t const idxPage = idxPageBeg + idxLoadedPage;
    ldgsts::copyAsync<nbBytes>(
        &dst[idxBeam][idxLoadedPage],
        &cacheList.kvCachePageList[beamWidth * 2 * maxNbPagesPerSeq * idxReq +
                                   2 * maxNbPagesPerSeq * idxBeam + (isK ? 0U : maxNbPagesPerSeq) +
                                   idxPage],
        idxPage < nbPages ? nbBytes : 0U);
  }
}

template <uint32_t nbWarps, uint32_t length, bool isFullTile = false>
__device__ inline void loadIndicesForBeamSearchAsync(uint32_t idxWarp, Vec<uint32_t, length>& dst,
                                                     BeamSearchParams const& params,
                                                     uint32_t idxReq, uint32_t idxBeam,
                                                     uint32_t uniformSeqOffset, uint32_t seqLen) {
  constexpr uint32_t nbThreads = warp_size * nbWarps;
  // constexpr uint32_t indicesPerInst =
  //     mha::min(exactDiv(grainBytes, sizeof(uint32_t)), divUp(length, nbThreads));
  // @fixme: std::bit_ceil on length
  constexpr uint32_t indicesPerInst = 1U;  // to handle unaligned case.
  constexpr uint32_t bytesPerInst = sizeof(uint32_t) * indicesPerInst;
  assertIsPowerOf2<indicesPerInst>();
  uint32_t const capacity = params.capacity;
  uint32_t const srcOffset = (idxReq * beamWidth + idxBeam) * capacity + uniformSeqOffset;
  uint32_t const tid = warp_size * idxWarp + laneId();
  constexpr uint32_t indicesPerIter = indicesPerInst * nbThreads;
#pragma unroll
  for (uint32_t i = 0; i < length / indicesPerIter; i++) {
    uint32_t const idx = indicesPerIter * i + indicesPerInst * tid;
    ldgsts::copyAsync<bytesPerInst>(
        &dst[idx], &params.indices[srcOffset + idx],
        (isFullTile || uniformSeqOffset + idx < seqLen) ? bytesPerInst : 0);
  }
  if constexpr (length % indicesPerIter != 0) {
    uint32_t const idx = indicesPerIter * (length / indicesPerIter) + indicesPerInst * tid;
    if (idx < length) {
      ldgsts::copyAsync<bytesPerInst>(
          &dst[idx], &params.indices[srcOffset + idx],
          (isFullTile || uniformSeqOffset + idx < seqLen) ? bytesPerInst : 0);
    }
  }
}

__device__ inline InputElem2 float2ToInputElem2(float2 src) {
  InputElem2 dst;
  if constexpr (mha::is_same_v<InputElem2, half2>) {
    reinterpret_cast<half2&>(dst) = __float22half2_rn(src);
    return dst;
  } else if constexpr (mha::is_same_v<InputElem2, nv_bfloat162>) {
    reinterpret_cast<nv_bfloat162&>(dst) = __float22bfloat162_rn(src);
    return dst;
  } else if constexpr (mha::is_same_v<InputElem2, __nv_fp8x2_e4m3>) {
    reinterpret_cast<__nv_fp8x2_e4m3&>(dst) = __nv_fp8x2_e4m3{src};
    return dst;
  } else {
    trap();
  }
}

template <bool real>
using TokenOrNone = RealTypeOrNone<real, CtaBarrier::arrival_token>;

template <bool real>
__device__ inline TokenOrNone<real> arrive(CtaBarrier* pBarrier) {
  if constexpr (real) {
    return pBarrier->arrive();
  } else {
    assert(pBarrier == nullptr);
    return None{};
  }
}

template <bool real>
__device__ inline void wait(CtaBarrier* pBarrier, TokenOrNone<real>&& token) {
  if constexpr (real) {
    pBarrier->wait(mha::move(token));
  } else {
    assert(pBarrier == nullptr);
    __syncwarp();
  }
}

template <bool real>
__device__ inline bool test_wait(CtaBarrier* pBarrier, TokenOrNone<real>&& token) {
  if constexpr (real) {
    uint32_t complete;
    asm volatile(
        "{\n"
        ".reg .pred       complete;\n"
        "mbarrier.test_wait.acquire.cta.shared::cta.b64 complete, [%1], %2;\n"
        "selp.b32 %0, 1, 0, complete;\n}\n"
        : "=r"(complete)
        : "l"(__cvta_generic_to_shared(pBarrier)), "l"(token));
    return bool(complete);
  } else {
    return false;
  }
}

template <bool real>
using ParityOrNone = RealTypeOrNone<real, bool>;

template <bool real>
__device__ inline void wait_parity(CtaBarrier* pBarrier, ParityOrNone<real> parity) {
  assert(real == (pBarrier != nullptr));
  if constexpr (real) {
    pBarrier->wait_parity(parity);
  } else {
    __syncwarp();
  }
}

template <bool real>
__device__ inline bool test_wait_parity(CtaBarrier* pBarrier, ParityOrNone<real> parity) {
  assert(real == (pBarrier != nullptr));
  if constexpr (real) {
#if USE_CUSTOM_BARRIER
    return pBarrier->test_wait_parity(parity);
#else
    return pBarrier->try_wait_parity_for(parity, cuda::std::chrono::nanoseconds(0));
#endif
  } else {
    return false;
  }
}

template <bool real = true>
__device__ inline ParityOrNone<real>& flip(ParityOrNone<real>& parity) {
  if constexpr (real) {
    parity = !parity;
  }
  return parity;
}

template <bool real = true>
__device__ inline ParityOrNone<real> getAndFlip(ParityOrNone<real>& flag) {
  ParityOrNone<real> const ret = flag;
  if constexpr (real) {
    flag = !flag;
  }
  return ret;
}
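
// Illustrative use of the parity helpers for alternating mbarrier phases (the barrier/flag names
// below are hypothetical):
//   ParityOrNone<real> parity{};                             // initial phase parity
//   wait_parity<real>(pBarrier, getAndFlip<real>(parity));   // wait on current phase, then flip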
